{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# K-Means Example\n",
    "\n",
    "Implement the K-Means algorithm with TensorFlow and apply it to classify\n",
    "handwritten digit images. This example uses the MNIST database of\n",
    "handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).\n",
    "\n",
    "Note: This example requires TensorFlow v1.1.0 or later in the 1.x series (it imports `tensorflow.contrib`, which is not available in TensorFlow 2.x).\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.contrib.factorization import KMeans\n",
    "\n",
    "# Ignore all GPUs: an empty CUDA_VISIBLE_DEVICES hides them, so this K-Means example runs on the CPU.\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Import MNIST data\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "# Read the MNIST data sets from /tmp/data/ with one-hot encoded labels\n",
    "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)\n",
    "# The training images are the samples fed to K-Means below\n",
    "full_data_x = mnist.train.images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "num_steps = 50 # Total steps to train\n",
    "# NOTE(review): batch_size is not referenced by any other cell in this notebook;\n",
    "# every training step feeds the whole of full_data_x.\n",
    "batch_size = 1024 # The number of samples per batch\n",
    "k = 25 # The number of clusters (each one is mapped to a digit label after training)\n",
    "num_classes = 10 # The 10 digits\n",
    "num_features = 784 # Each image is 28x28 pixels\n",
    "\n",
    "# Input images: one flattened image of num_features values per row\n",
    "X = tf.placeholder(tf.float32, shape=[None, num_features])\n",
    "# Labels (for assigning a label to a centroid and testing), one-hot over num_classes\n",
    "Y = tf.placeholder(tf.float32, shape=[None, num_classes])\n",
    "\n",
    "# K-Means Parameters: k clusters over X, cosine distance, mini-batch mode enabled\n",
    "kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',\n",
    "                use_mini_batch=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Construct the K-Means training ops and unpack them by position\n",
    "training_graph = kmeans.training_graph()\n",
    "(all_scores, cluster_idx, scores, cluster_centers_initialized,\n",
    " cluster_centers_vars, init_op, train_op) = training_graph\n",
    "# cluster_idx comes back as a tuple; keep only its first entry\n",
    "cluster_idx = cluster_idx[0]\n",
    "# Mean of the scores, reported as the average distance while training\n",
    "avg_distance = tf.reduce_mean(scores)\n",
    "\n",
    "# Op that initializes every global variable in the graph\n",
    "init_vars = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1, Avg Distance: 0.341471\n",
      "Step 10, Avg Distance: 0.221609\n",
      "Step 20, Avg Distance: 0.220328\n",
      "Step 30, Avg Distance: 0.219776\n",
      "Step 40, Avg Distance: 0.219419\n",
      "Step 50, Avg Distance: 0.219154\n"
     ]
    }
   ],
   "source": [
    "# Open the TensorFlow session used for training and evaluation\n",
    "sess = tf.Session()\n",
    "\n",
    "# Every run below is fed the full set of training images\n",
    "train_feed = {X: full_data_x}\n",
    "\n",
    "# Initialize the global variables first, then run the K-Means init op\n",
    "sess.run(init_vars, feed_dict=train_feed)\n",
    "sess.run(init_op, feed_dict=train_feed)\n",
    "\n",
    "# Training: one K-Means update per step; idx keeps the cluster assignment\n",
    "# fetched on the last step and is reused when labelling the centroids\n",
    "for step in range(1, num_steps + 1):\n",
    "    _, d, idx = sess.run([train_op, avg_distance, cluster_idx],\n",
    "                         feed_dict=train_feed)\n",
    "    # Report on the first step and then on every tenth one\n",
    "    if step == 1 or step % 10 == 0:\n",
    "        print(\"Step %i, Avg Distance: %f\" % (step, d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy: 0.7127\n"
     ]
    }
   ],
   "source": [
    "# Assign a label to each centroid\n",
    "# Count total number of labels per centroid, using the label of each training\n",
    "# sample to their closest centroid (given by 'idx')\n",
    "counts = np.zeros(shape=(k, num_classes))\n",
    "# Unbuffered scatter-add: row idx[i] of counts accumulates the one-hot label of\n",
    "# sample i, exactly like a per-sample loop but without iterating in Python\n",
    "np.add.at(counts, idx, mnist.train.labels)\n",
    "# Assign the most frequent label to the centroid (a centroid with no samples gets 0)\n",
    "centroid_labels = np.argmax(counts, axis=1).astype(np.int32)\n",
    "# Explicit int32 so the lookup result matches the int32 cast of the true labels below\n",
    "labels_map = tf.convert_to_tensor(centroid_labels)\n",
    "\n",
    "# Evaluation ops\n",
    "# Lookup: centroid_id -> label\n",
    "cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx)\n",
    "# Compute accuracy: fraction of samples whose centroid label equals the true digit\n",
    "correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32))\n",
    "accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Test Model on the held-out MNIST test split\n",
    "test_x, test_y = mnist.test.images, mnist.test.labels\n",
    "print(\"Test Accuracy:\", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
